{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e2893f93-72d9-40ea-93e9-be2d3f3f66ee",
   "metadata": {},
   "source": [
    "# Distributed ML model batch inference on data in DeltaLake\n",
    "\n",
    "In this tutorial, we showcase how to perform ML model batch inference on data in a DeltaLake table.\n",
    "\n",
    "This is a continuation of the previous tutorial on **local** batch inference, which is a great way to get started and make sure that your code is working before graduating to larger scales in a distributed batch inference workload. Make sure to give that a read before looking at this tutorial!\n",
    "\n",
    "To run this tutorial you will require AWS credentials to be correctly provisioned on your machine as all data is hosted in a requestor-pays bucket in AWS S3.\n",
    "\n",
    "Let's get started!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2be7773d",
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "CI = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83b289cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip this notebook execution in CI because it hits non-public buckets\n",
    "if CI:\n",
    "    import sys\n",
    "\n",
    "    sys.exit()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc8e6da",
   "metadata": {},
   "source": [
    "# Going Distributed\n",
    "\n",
    "The first step (and most important for this demo!) is to switch our Daft runner to the Ray runner, and point it at a Ray cluster. This is super simple:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8e600443-3931-44f2-b814-0056e42da612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DaftContext(_daft_execution_config=<daft.daft.PyDaftExecutionConfig object at 0x1039afc90>, _daft_planning_config=<daft.daft.PyDaftPlanningConfig object at 0x1039afc10>, _runner_config=_RayRunnerConfig(address=None, max_task_backlog=None), _disallow_set_runner=True, _runner=None)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import daft\n",
    "\n",
    "# If you have your own Ray cluster running, feel free to set this to that address!\n",
    "# RAY_ADDRESS = \"ray://localhost:10001\"\n",
    "RAY_ADDRESS = None\n",
    "\n",
    "daft.set_runner_ray(address=RAY_ADDRESS)"
   ]
  },
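  {
   "cell_type": "markdown",
   "id": "ray-connection-check-sketch",
   "metadata": {},
   "source": [
    "If you are pointing at your own cluster, it can help to confirm that the address is reachable before running any Daft queries. The snippet below is an optional sketch rather than part of the tutorial flow. It assumes Ray is installed and uses only `ray.init` and `ray.cluster_resources`; the `address` variable is a stand-in for `RAY_ADDRESS` above.\n",
    "\n",
    "```python\n",
    "import ray\n",
    "\n",
    "# Stand-in for RAY_ADDRESS: None typically starts a local Ray instance,\n",
    "# while something like \"ray://localhost:10001\" connects to a running cluster\n",
    "address = None\n",
    "\n",
    "# ignore_reinit_error avoids an exception if Ray is already initialized in this process\n",
    "ray.init(address=address, ignore_reinit_error=True)\n",
    "\n",
    "# Prints the total resources (such as CPUs and memory) that Ray can schedule work on\n",
    "print(ray.cluster_resources())\n",
    "```"
   ]
  },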
  {
   "cell_type": "markdown",
   "id": "1fdf0722-eff4-485d-84e8-f4c74e79caca",
   "metadata": {},
   "source": [
    "Now, we run the same operations as before. The only difference is that instead of execution happening locally on the machine that's running this code, Daft will distribute the computation over your Ray cluster!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "569c6297-dcfd-4013-9a04-e3fc8f0ca315",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feel free to tweak this variable to have the tutorial run on as many rows as you'd like!\n",
    "NUM_ROWS = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08ff4bf0-7b5f-4884-80d3-95d7b9005a8b",
   "metadata": {},
   "source": [
    "### Retrieving data\n",
    "\n",
    "We will be retrieving the data exactly the same way we did in the previous tutorial, with the same API and arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "83a76976-aed6-49ea-8c8e-1572947d93ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Provision Cloud Credentials\n",
    "import boto3\n",
    "\n",
    "import daft\n",
    "\n",
    "session = boto3.session.Session()\n",
    "creds = session.get_credentials()\n",
    "io_config = daft.io.IOConfig(\n",
    "    s3=daft.io.S3Config(\n",
    "        access_key=creds.secret_key,\n",
    "        key_id=creds.access_key,\n",
    "        session_token=creds.token,\n",
    "        region_name=\"us-west-2\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# Retrieve data\n",
    "df = daft.read_deltalake(\"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/\", io_config=io_config)\n",
    "\n",
    "# Prune data\n",
    "df = df.limit(NUM_ROWS)\n",
    "df = df.where(df[\"object\"].list.length() == 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8043c7e1-c350-449b-bd93-4a5ca93adc4d",
   "metadata": {},
   "source": [
    "### Splitting the data into more partitions\n",
    "\n",
    "We now split the data into more partitions for additional parallelism when performing our data processing in a **distributed** fashion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d553c284-5f4e-435b-a0f2-c35bf4fc09d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = df.into_partitions(16)"
   ]
  },
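  {
   "cell_type": "markdown",
   "id": "partition-sizing-sketch",
   "metadata": {},
   "source": [
    "To get a feel for what 16 partitions means for this workload, here is a rough back-of-the-envelope sketch in plain Python. The helper `partition_sizes` is hypothetical and only illustrates an even split; it is not necessarily how Daft assigns rows to partitions, and the real row count will be lower than `NUM_ROWS` after filtering.\n",
    "\n",
    "```python\n",
    "def partition_sizes(num_rows: int, num_partitions: int) -> list[int]:\n",
    "    # Split num_rows as evenly as possible across num_partitions\n",
    "    base, extra = divmod(num_rows, num_partitions)\n",
    "    # The first `extra` partitions take one additional row\n",
    "    return [base + 1 if i < extra else base for i in range(num_partitions)]\n",
    "\n",
    "\n",
    "sizes = partition_sizes(1000, 16)\n",
    "print(sizes[:2], sizes[-2:], sum(sizes))  # [63, 63] [62, 62] 1000\n",
    "```\n",
    "\n",
    "With roughly 60 rows per partition, up to 16 tasks can download, decode and classify images at the same time, provided the cluster has enough free workers."
   ]
  },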
  {
   "cell_type": "markdown",
   "id": "acc6f220-aaef-463c-ae51-014bddc14231",
   "metadata": {},
   "source": [
    "### Retrieving the images and preprocessing\n",
    "\n",
    "Now we continue with exactly the same code as in the local case for retrieving and preprocessing our images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "317e2778-4986-4993-ab4d-0426e5fee149",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Retrieve images and run preprocessing\n",
    "df = df.with_column(\n",
    "    \"image_url\", \"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/\" + df[\"filename\"] + \".jpeg\"\n",
    ")\n",
    "df = df.with_column(\"image\", df[\"image_url\"].download().decode_image())\n",
    "df = df.with_column(\"image_resized_small\", df[\"image\"].resize(32, 32))\n",
    "df = df.with_column(\"image_resized_large\", df[\"image\"].resize(256, 256))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7dc7cb6a-9c03-4386-8f31-7e5559b370b3",
   "metadata": {},
   "source": [
    "### Running batch inference with a UDF\n",
    "\n",
    "Running the UDF is also exactly the same!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "29b21c64-026e-43bd-aea0-48ae3a452b7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-29 19:38:18,040\tINFO worker.py:1642 -- Started a local Ray instance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c76dafc0f26b4f4782089c7381e4baa0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "ScanWithTask-LocalLimit-LocalLimit-Project-Filter [Stage:3]:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e8f82df0b664783bf61ca11686ee421",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FanoutSlices [Stage:2]:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "52508fcdaa9740549d46c050f410bc1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Project-Project-Filter-Project-WriteFile [Stage:1]:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<table class=\"dataframe\">\n",
       "<thead><tr><th style=\"text-wrap: nowrap; max-width:192px; overflow:auto; text-align:left\">path<br />Utf8</th></tr></thead>\n",
       "<tbody>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/8eb54f00-9537-4e28-ac85-e96a00a071d5-0.parquet</div></td></tr>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/04ccf8fe-9777-4307-9e1f-916c8532ca1c-0.parquet</div></td></tr>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/867fc77f-f730-4b53-8e9a-11ed5dc9b98f-0.parquet</div></td></tr>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/e4645f7b-8a70-4ee8-8221-823777467a0a-0.parquet</div></td></tr>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/dd41fced-6e6b-4ece-8e58-d0804311b4ff-0.parquet</div></td></tr>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/c548e6f4-3c83-4f76-b7c5-821f81157720-0.parquet</div></td></tr>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/28753019-9875-45a2-94b4-b7b9217492ca-0.parquet</div></td></tr>\n",
       "<tr><td><div style=\"text-align:left; max-width:192px; max-height:64px; overflow:auto\">my_results.parquet/f66ffaa6-cc2e-4328-8137-aa358244a8a3-0.parquet</div></td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "<small>(Showing first 8 of 16 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "╭────────────────────────────────╮\n",
       "│ path                           │\n",
       "│ ---                            │\n",
       "│ Utf8                           │\n",
       "╞════════════════════════════════╡\n",
       "│ my_results.parquet/8eb54f00-9… │\n",
       "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n",
       "│ my_results.parquet/04ccf8fe-9… │\n",
       "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n",
       "│ my_results.parquet/867fc77f-f… │\n",
       "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n",
       "│ my_results.parquet/e4645f7b-8… │\n",
       "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n",
       "│ my_results.parquet/dd41fced-6… │\n",
       "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n",
       "│ my_results.parquet/c548e6f4-3… │\n",
       "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n",
       "│ my_results.parquet/28753019-9… │\n",
       "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n",
       "│ my_results.parquet/f66ffaa6-c… │\n",
       "╰────────────────────────────────╯\n",
       "\n",
       "(Showing first 8 of 16 rows)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run batch inference over the entire dataset\n",
    "import numpy as np\n",
    "import torch\n",
    "from torchvision.models import ResNet50_Weights, resnet50\n",
    "\n",
    "import daft\n",
    "\n",
    "\n",
    "@daft.udf(return_dtype=daft.DataType.string())\n",
    "class ClassifyImage:\n",
    "    def __init__(self):\n",
    "        weights = ResNet50_Weights.DEFAULT\n",
    "        self.model = resnet50(weights=weights)\n",
    "        self.model.eval()\n",
    "        self.preprocess = weights.transforms()\n",
    "        self.category_map = weights.meta[\"categories\"]\n",
    "\n",
    "    def __call__(self, images: daft.Series, shape: list[int, int, int]):\n",
    "        if len(images) == 0:\n",
    "            return []\n",
    "\n",
    "        # Convert the Daft Series into a list of Numpy arrays\n",
    "        data = images.cast(daft.DataType.tensor(daft.DataType.uint8(), tuple(shape))).to_pylist()\n",
    "\n",
    "        # Convert the numpy arrays into a torch tensor\n",
    "        images_array = torch.tensor(np.array(data)).permute((0, 3, 1, 2))\n",
    "\n",
    "        # Run the model, and map results back to a human-readable string\n",
    "        batch = self.preprocess(images_array)\n",
    "        prediction = self.model(batch).softmax(0)\n",
    "        class_ids = prediction.argmax(1)\n",
    "        prediction[:, class_ids]\n",
    "        return [self.category_map[class_id] for class_id in class_ids]\n",
    "\n",
    "\n",
    "# Filter out rows where the channel != 3\n",
    "df = df.where(df[\"image\"].apply(lambda img: img.shape[2] == 3, return_dtype=daft.DataType.bool()))\n",
    "\n",
    "df = df.with_column(\"predictions_lowres\", ClassifyImage(df[\"image_resized_small\"], [32, 32, 3]))\n",
    "df = df.with_column(\"predictions_highres\", ClassifyImage(df[\"image_resized_large\"], [256, 256, 3]))\n",
    "\n",
    "# Prune the results and write data back out as Parquet\n",
    "df = df.select(\n",
    "    \"filename\",\n",
    "    \"image_url\",\n",
    "    \"object\",\n",
    "    \"predictions_lowres\",\n",
    "    \"predictions_highres\",\n",
    ")\n",
    "df.write_parquet(\"my_results.parquet\")"
   ]
  },
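  {
   "cell_type": "markdown",
   "id": "tensor-layout-softmax-sketch",
   "metadata": {},
   "source": [
    "The tensor handling inside the UDF can be easier to follow in isolation. The sketch below uses made-up stand-ins (an all-zero image batch and random logits instead of a real ResNet50 output) to show the layout change from `(batch, height, width, channel)` to `(batch, channel, height, width)`, and how per-class probabilities are reduced to one class ID per image. It assumes 1000 output classes, as in ImageNet.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Stand-in for decoded images: 4 RGB images of 32x32, channels last\n",
    "images_nhwc = np.zeros((4, 32, 32, 3), dtype=np.uint8)\n",
    "\n",
    "# torchvision models expect channels first: (batch, channel, height, width)\n",
    "images_nchw = torch.tensor(images_nhwc).permute((0, 3, 1, 2))\n",
    "print(images_nchw.shape)  # torch.Size([4, 3, 32, 32])\n",
    "\n",
    "# Stand-in for model output: one row of logits per image, one column per class\n",
    "logits = torch.randn(4, 1000)\n",
    "\n",
    "# Normalize across classes (dim=1) so that each row sums to 1\n",
    "probs = logits.softmax(dim=1)\n",
    "\n",
    "# Pick the most likely class per image, as plain Python ints\n",
    "class_ids = probs.argmax(dim=1).tolist()\n",
    "print(len(class_ids))  # 4\n",
    "```\n",
    "\n",
    "Each class ID can then be used as an index into a list of category names, which is what the UDF does with `weights.meta[\"categories\"]`."
   ]
  },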
  {
   "cell_type": "markdown",
   "id": "fa593498-655b-4e41-87bc-05fe70a5ab66",
   "metadata": {},
   "source": [
    "# Now, take a look at your handiwork!\n",
    "\n",
    "Let's read the results of our distributed Daft job!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c82c38f5-dccd-484c-856a-92362e147412",
   "metadata": {},
   "outputs": [],
   "source": [
    "daft.read_parquet(\"my_results.parquet\").collect()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
